# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sympy as sm
from hysop.tools.numpywrappers import npw
from hysop.tools.htypes import check_instance, first_not_None, to_tuple
from hysop.symbolic import Expr, Symbol
# Value returned by the _ccode() printers in this module after they have
# written their output directly through printer.codegen (code sections and
# mutex lock/unlock). NOTE(review): presumably signals to the calling printer
# that no further text or statement terminator must be emitted — confirm.
InstructionTermination: str = ""
class Select(Expr):
    """Symbolic ternary selection, printed as ``(c ? b : a)``.

    Parameters
    ----------
    a :
        Value printed in the "false" position.
    b :
        Value printed in the "true" position.
    c :
        Condition.
    *args :
        Extra arguments forwarded unchanged to ``Expr.__new__``.
    """

    def __new__(cls, a, b, c, *args):
        """Equivalent to ternary operator: (c ? b : a)"""
        instance = super().__new__(cls, a, b, c, *args)
        # Keep direct handles on the operands for the printers below.
        instance.a, instance.b, instance.c = a, b, c
        return instance

    def __str__(self):
        cond, if_true, if_false = self.c, self.b, self.a
        return f"({cond} ? {if_true} : {if_false})"

    def __repr__(self):
        operands = ", ".join(repr(x) for x in (self.a, self.b, self.c))
        return f"Select({operands})"

    def _sympystr(self, printer):
        render = printer._print
        return "({} ? {} : {})".format(render(self.c), render(self.b), render(self.a))
class Expand(Expr):
    """Expand a vector by a given factor.

    Each component is repeated ``factor`` times in place:
    v=v.xy => v.xxyy (factor = 2)

    Parameters
    ----------
    expr :
        Vector expression to expand.
    factor :
        Repetition factor.
    *args :
        Extra arguments forwarded unchanged to ``Expr.__new__``.
    """

    def __new__(cls, expr, factor, *args):
        instance = super().__new__(cls, expr, factor, *args)
        instance.expr, instance.factor = expr, factor
        return instance

    def __str__(self):
        return "Expand({}, {})".format(self.expr, self.factor)

    def __repr__(self):
        # Same text as str(); kept as a method so subclasses overriding
        # __str__ are honoured.
        return str(self)

    def _sympystr(self, printer):
        rendered = printer._print(self.expr)
        return "Expand({}, {})".format(rendered, self.factor)
class BroadCast(Expr):
    """
    BroadCast a vector by a given factor.

    The whole vector is tiled ``factor`` times:
    v=v.xy => v.xyxy (factor = 2)

    Parameters
    ----------
    expr :
        Vector expression to broadcast.
    factor :
        Tiling factor.
    """

    def __new__(cls, expr, factor):
        instance = super().__new__(cls, expr, factor)
        instance.expr, instance.factor = expr, factor
        return instance

    def __str__(self):
        return "BroadCast({}, {})".format(self.expr, self.factor)

    def __repr__(self):
        # Same text as str(); kept as a method so subclasses overriding
        # __str__ are honoured.
        return str(self)

    def _sympystr(self, printer):
        rendered = printer._print(self.expr)
        return "BroadCast({}, {})".format(rendered, self.factor)
class Cast(Expr):
    """
    Cast a scalar or a vector to another type.

    Parameters
    ----------
    expr :
        Expression to convert.
    dtype :
        Target type; only its string form is used when printing here.
    """

    def __new__(cls, expr, dtype):
        instance = super().__new__(cls, expr, dtype)
        instance.expr, instance.dtype = expr, dtype
        return instance

    def __str__(self):
        return "Cast({}, {})".format(self.expr, self.dtype)

    def __repr__(self):
        # Same text as str(); kept as a method so subclasses overriding
        # __str__ are honoured.
        return str(self)

    def _sympystr(self, printer):
        rendered = printer._print(self.expr)
        return "Cast({}, {})".format(rendered, self.dtype)
class CodeSection(Expr):
    """A sequence of expressions emitted together inside one codegen block.

    Parameters
    ----------
    *exprs :
        At least one expression. Bare Python containers (tuple, list, set,
        frozenset, dict) are rejected; pass their items as separate
        arguments instead.
    """

    def __new__(cls, *exprs):
        forbidden = (tuple, list, set, frozenset, dict)
        assert exprs
        # Exact type match on purpose (subclasses of these are not rejected).
        assert not any(type(item) in forbidden for item in exprs)
        return super().__new__(cls, *exprs)

    def _sympystr(self, printer):
        body = "; ".join(printer._print(arg) for arg in self.args)
        return f"CodeSection([{body}])"

    def _ccode(self, printer):
        # Every argument is printed inside the block opened by the code
        # generator; the individual return values are discarded.
        # NOTE(review): this assumes sub-expressions emit their own code via
        # printer.codegen (as MutexLock/MutexUnlock do) — confirm for others.
        with printer.codegen._block_():
            for stmt in self.args:
                printer._print(stmt)
        return InstructionTermination
class MutexOp(Expr):
    """Base class for symbolic operations on one entry of a mutex array.

    Parameters
    ----------
    mutexes :
        Mutex storage; printed as a pointer/array in generated code
        (``mutexes+mutex_id`` / ``mutexes[mutex_id]`` in the subclasses).
    mutex_id :
        Index of the mutex to operate on.
    *args :
        Extra arguments forwarded unchanged to ``Expr.__new__``.
    """

    def __new__(cls, mutexes, mutex_id, *args):
        instance = super().__new__(cls, mutexes, mutex_id, *args)
        instance.mutexes, instance.mutex_id = mutexes, mutex_id
        return instance
class MutexLock(MutexOp):
    """Acquire ``mutexes[mutex_id]`` by spinning on ``atomic_cmpxchg``."""

    def _sympystr(self, printer):
        render = printer._print
        return f"MutexLock({render(self.mutexes)}, {render(self.mutex_id)})"

    def _ccode(self, printer):
        # Busy-wait until the compare-exchange swaps 0 -> 1. Assuming OpenCL
        # atomic_cmpxchg semantics (returns the previous value), a result of 1
        # means another work-item currently holds the lock.
        template = "while (atomic_cmpxchg({}+{}, 0, 1) == 1);"
        printer.codegen.append(
            template.format(printer._print(self.mutexes), printer._print(self.mutex_id))
        )
        return InstructionTermination
class MutexUnlock(MutexOp):
    """Release ``mutexes[mutex_id]`` by writing 0 to it."""

    def _sympystr(self, printer):
        render = printer._print
        return f"MutexUnlock({render(self.mutexes)}, {render(self.mutex_id)})"

    def _ccode(self, printer):
        # NOTE(review): this is a plain (non-atomic, unfenced) store; depending
        # on the target memory model an atomic_xchg and/or memory fence may be
        # needed for critical-section writes to be visible to the next owner.
        template = "{}[{}] = 0;"
        printer.codegen.append(
            template.format(printer._print(self.mutexes), printer._print(self.mutex_id))
        )
        return InstructionTermination
def CriticalCodeSection(exprs, mutexes, mutex_id, *args):
    """Build a CodeSection whose body is protected by a mutex.

    Parameters
    ----------
    exprs :
        Statement or statements to protect; normalized with ``to_tuple``.
    mutexes, mutex_id :
        Forwarded to ``MutexLock`` and ``MutexUnlock`` to select the mutex.
    *args :
        Accepted for backward compatibility only; currently ignored.

    Returns
    -------
    CodeSection
        ``MutexLock(mutexes, mutex_id)``, then every item of ``exprs``, then
        ``MutexUnlock(mutexes, mutex_id)``.
    """
    # The previous unused function-local import of
    # hysop.symbolic.relational.Assignment has been dropped.
    body = to_tuple(exprs)
    return CodeSection(
        MutexLock(mutexes, mutex_id),
        *body,
        MutexUnlock(mutexes, mutex_id),
    )
class ApplyStencil(Expr):
    """Symbolic application of a stencil to an expression.

    Parameters
    ----------
    expr :
        Expression the stencil is applied to.
    stencil :
        Stencil object; only its ``coeffs`` attribute is used here (printing).
    """

    def __new__(cls, expr, stencil):
        instance = super().__new__(cls, expr, stencil)
        instance.expr, instance.stencil = expr, stencil
        return instance

    def _sympystr(self, printer):
        rendered = printer._print(self.expr)
        return "ApplyStencil({}, {})".format(rendered, self.stencil.coeffs)
class TimeIntegrate(Expr):
    """
    Represents variable time integration for code generation.

    Parameters
    ----------
    time_integrator: TimeIntegrator
        Integrator; its ``name()`` is used in the string form.
    lhs : Field
        Integrated variable.
    rhs : Expr
        Right-hand side of the integration.
    """

    def __new__(cls, time_integrator, lhs, rhs):
        obj = super().__new__(cls, time_integrator, lhs, rhs)
        obj.time_integrator = time_integrator
        obj.lhs = lhs
        obj.rhs = rhs
        return obj

    def __str__(self):
        return f"{self.lhs} = {self.time_integrator.name()}({self.rhs})"

    def _ccode(self, printer=None):
        """Refuse direct C printing.

        Raises
        ------
        RuntimeError
            Always; this node has no direct C code representation.
        """
        # sympy printers call ``expr._ccode(printer)`` (every other _ccode in
        # this module takes ``printer``). The former ``_ccode(self)`` signature
        # therefore raised TypeError instead of the intended RuntimeError.
        raise RuntimeError(
            "TimeIntegrate has no direct C code representation."
        )